{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "# 手动搭建线性模型\n",
    "此节内容来源于 [徒手实现发向传播算法——分布式训练、GPU运算等](https://www.bilibili.com/video/BV1Xm42157Jr) 。\n",
    "\n",
    "在上一节中，我们简单实现了反向传播的模型，但是并没有用进行搭建神经网络，我们撰写的代码逻辑实际上和 PyTorch 是差不多的，所以这一节我们用自己的工具来实现神经网络，并且加上可视化的工具。\n",
    "\n",
    "本节的内容包括：\n",
    "- 最优化算法\n",
    "- 计算图膨胀\n",
    "- 梯度积累\n",
    "- 参数冻结\n",
    "- 随机失活 (dropout)\n",
    "\n",
    "下面的附加内容请移步B站视频学习：\n",
    "- GPU运算\n",
    "- 混合精度"
   ],
   "id": "9490e34053b73626"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-12T04:14:44.249688Z",
     "start_time": "2025-08-12T04:14:42.741809Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from tools.ch07_autograd.utils import Scalar, draw_graph\n",
    "from tools.ch07_autograd.linear_model import Linear, mse\n",
    "\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(1024)  # 使结果可以复现"
   ],
   "id": "c2833119298cb754",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x182fffab510>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 1
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-12T04:14:44.258798Z",
     "start_time": "2025-08-12T04:14:44.253781Z"
    }
   },
   "cell_type": "code",
   "source": [
    "x = torch.linspace(100, 300, 200)\n",
    "x = (x - torch.mean(x)) / torch.std(x)\n",
    "epsilon = torch.randn(x.shape)\n",
    "y = 10 * x + 5 + epsilon"
   ],
   "id": "7944a47d85dc46e5",
   "outputs": [],
   "execution_count": 2
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-12T04:14:44.521980Z",
     "start_time": "2025-08-12T04:14:44.382123Z"
    }
   },
   "cell_type": "code",
   "source": [
    "model = Linear()\n",
    "\n",
    "batch_size = 20\n",
    "learning_rate = 0.1\n",
    "\n",
    "for t in range(20):\n",
    "    ix = (t * batch_size) % len(x)\n",
    "    xx = x[ix: ix + batch_size]\n",
    "    yy = y[ix: ix + batch_size]\n",
    "\n",
    "    # 计算均方误差，虽然我们的代码在训练模型的时候没有用到 Pytorch ，但是我们的测试数据是由 Pytorch 生成的，\n",
    "    # 均方误差需要考虑的就是本批次内所有东西的误差，这样写是与我们自定义实现的 mse 函数有关。\n",
    "    loss = mse([model.error(_x, _y) for _x, _y in zip(xx, yy)])\n",
    "    loss.backward()\n",
    "\n",
    "    model.a -= learning_rate * model.a.grad\n",
    "    model.b -= learning_rate * model.b.grad\n",
    "\n",
    "    model.a.grad = 0.0\n",
    "    model.b.grad = 0.0\n",
    "\n",
    "    print(model.string())"
   ],
   "id": "e2d68f524d6ab030",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "y = 3.12 * x + -1.99\n",
      "y = 3.48 * x + -2.28\n",
      "y = 3.22 * x + -1.97\n",
      "y = 2.85 * x + -1.22\n",
      "y = 2.68 * x + -0.23\n",
      "y = 2.92 * x + 1.08\n",
      "y = 3.74 * x + 2.61\n",
      "y = 5.07 * x + 4.15\n",
      "y = 6.73 * x + 5.52\n",
      "y = 8.22 * x + 6.48\n",
      "y = 9.36 * x + 5.75\n",
      "y = 9.75 * x + 5.42\n",
      "y = 9.88 * x + 5.28\n",
      "y = 9.89 * x + 5.26\n",
      "y = 9.89 * x + 5.20\n",
      "y = 9.88 * x + 5.18\n",
      "y = 9.88 * x + 5.17\n",
      "y = 9.84 * x + 5.14\n",
      "y = 9.86 * x + 5.15\n",
      "y = 9.94 * x + 5.21\n"
     ]
    }
   ],
   "execution_count": 3
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-12T04:14:45.314261Z",
     "start_time": "2025-08-12T04:14:44.531452Z"
    }
   },
   "cell_type": "code",
   "source": [
    "### 计算图膨胀\n",
    "model = Linear()\n",
    "\n",
    "# 定义两组数据\n",
    "x1 = Scalar(1.5, label='x1', requires_grad=False)\n",
    "y1 = Scalar(1.0, label='y1', requires_grad=False)\n",
    "x2 = Scalar(2.0, label='x2', requires_grad=False)\n",
    "y2 = Scalar(4.0, label='y2', requires_grad=False)\n",
    "\n",
    "loss = mse([model.error(x1, y1), model.error(x2, y2)])\n",
    "loss.backward()\n",
    "\n",
    "draw_graph(loss, 'backward')"
   ],
   "id": "c2e08ae5fb5ceb20",
   "outputs": [
    {
     "data": {
      "image/svg+xml": "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Generated by graphviz version 12.1.0 (20240811.2233)\n -->\n<!-- Pages: 1 -->\n<svg width=\"307pt\" height=\"512pt\"\n viewBox=\"0.00 0.00 307.13 511.75\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 507.75)\">\n<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-507.75 303.13,-507.75 303.13,4 -4,4\"/>\n<!-- 1660817535056backward -->\n<g id=\"node1\" class=\"node\">\n<title>1660817535056backward</title>\n<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M86.63,-111.75C86.63,-111.75 125.13,-111.75 125.13,-111.75 131.13,-111.75 137.13,-117.75 137.13,-123.75 137.13,-123.75 137.13,-157.5 137.13,-157.5 137.13,-163.5 131.13,-169.5 125.13,-169.5 125.13,-169.5 86.63,-169.5 86.63,-169.5 80.63,-169.5 74.63,-163.5 74.63,-157.5 74.63,-157.5 74.63,-123.75 74.63,-123.75 74.63,-117.75 80.63,-111.75 86.63,-111.75\"/>\n<text text-anchor=\"middle\" x=\"105.88\" y=\"-156\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">grad=&#45;1.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"74.63,-150.25 137.13,-150.25\"/>\n<text text-anchor=\"middle\" x=\"105.88\" y=\"-136.75\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">value=0.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"74.63,-131 137.13,-131\"/>\n<text text-anchor=\"middle\" x=\"105.88\" y=\"-117.5\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">*</text>\n</g>\n<!-- 1660817709648backward -->\n<g id=\"node2\" class=\"node\">\n<title>1660817709648backward</title>\n<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"105.88\" cy=\"-29.38\" rx=\"29.88\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"105.88\" y=\"-25.5\" font-family=\"Times New Roman,serif\" 
font-size=\"10.00\">x1=1.50</text>\n</g>\n<!-- 1660817535056backward&#45;&gt;1660817709648backward -->\n<g id=\"edge4\" class=\"edge\">\n<title>1660817535056backward&#45;&gt;1660817709648backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" stroke-dasharray=\"5,2\" d=\"M105.88,-111.49C105.88,-91.39 105.88,-64.96 105.88,-47.83\"/>\n</g>\n<!-- 1660817526672backward -->\n<g id=\"node12\" class=\"node\">\n<title>1660817526672backward</title>\n<polygon fill=\"lightgreen\" stroke=\"black\" stroke-width=\"2\" points=\"155.63,-0.5 155.63,-58.25 218.13,-58.25 218.13,-0.5 155.63,-0.5\"/>\n<text text-anchor=\"middle\" x=\"186.88\" y=\"-44.75\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">grad=&#45;9.50</text>\n<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"155.63,-39 218.13,-39\"/>\n<text text-anchor=\"middle\" x=\"186.88\" y=\"-25.5\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">value=0.00</text>\n<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"155.63,-19.75 218.13,-19.75\"/>\n<text text-anchor=\"middle\" x=\"186.88\" y=\"-6.25\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">a</text>\n</g>\n<!-- 1660817535056backward&#45;&gt;1660817526672backward -->\n<g id=\"edge12\" class=\"edge\">\n<title>1660817535056backward&#45;&gt;1660817526672backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" d=\"M126.74,-111.49C136.66,-98.11 148.66,-81.93 159.32,-67.55\"/>\n<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"162.05,-69.75 165.19,-59.63 156.42,-65.58 162.05,-69.75\"/>\n<text text-anchor=\"middle\" x=\"165.48\" y=\"-79.95\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;1.50</text>\n</g>\n<!-- 1660817533392backward -->\n<g id=\"node3\" class=\"node\">\n<title>1660817533392backward</title>\n<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M248.63,-111.75C248.63,-111.75 287.13,-111.75 287.13,-111.75 293.13,-111.75 299.13,-117.75 299.13,-123.75 299.13,-123.75 
299.13,-157.5 299.13,-157.5 299.13,-163.5 293.13,-169.5 287.13,-169.5 287.13,-169.5 248.63,-169.5 248.63,-169.5 242.63,-169.5 236.63,-163.5 236.63,-157.5 236.63,-157.5 236.63,-123.75 236.63,-123.75 236.63,-117.75 242.63,-111.75 248.63,-111.75\"/>\n<text text-anchor=\"middle\" x=\"267.88\" y=\"-156\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">grad=&#45;4.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"236.63,-150.25 299.13,-150.25\"/>\n<text text-anchor=\"middle\" x=\"267.88\" y=\"-136.75\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">value=0.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"236.63,-131 299.13,-131\"/>\n<text text-anchor=\"middle\" x=\"267.88\" y=\"-117.5\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">*</text>\n</g>\n<!-- 1660817531472backward -->\n<g id=\"node4\" class=\"node\">\n<title>1660817531472backward</title>\n<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"267.88\" cy=\"-29.38\" rx=\"29.88\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"267.88\" y=\"-25.5\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">x2=2.00</text>\n</g>\n<!-- 1660817533392backward&#45;&gt;1660817531472backward -->\n<g id=\"edge2\" class=\"edge\">\n<title>1660817533392backward&#45;&gt;1660817531472backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" stroke-dasharray=\"5,2\" d=\"M267.88,-111.49C267.88,-91.39 267.88,-64.96 267.88,-47.83\"/>\n</g>\n<!-- 1660817533392backward&#45;&gt;1660817526672backward -->\n<g id=\"edge6\" class=\"edge\">\n<title>1660817533392backward&#45;&gt;1660817526672backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" d=\"M247.02,-111.49C237.1,-98.11 225.11,-81.93 214.45,-67.55\"/>\n<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"217.34,-65.58 208.57,-59.63 211.72,-69.75 217.34,-65.58\"/>\n<text text-anchor=\"middle\" x=\"246.48\" y=\"-79.95\" font-family=\"Times New Roman,serif\" 
font-size=\"14.00\">&#45;8.00</text>\n</g>\n<!-- 1660817536208backward -->\n<g id=\"node5\" class=\"node\">\n<title>1660817536208backward</title>\n<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M127.63,-445.5C127.63,-445.5 166.13,-445.5 166.13,-445.5 172.13,-445.5 178.13,-451.5 178.13,-457.5 178.13,-457.5 178.13,-491.25 178.13,-491.25 178.13,-497.25 172.13,-503.25 166.13,-503.25 166.13,-503.25 127.63,-503.25 127.63,-503.25 121.63,-503.25 115.63,-497.25 115.63,-491.25 115.63,-491.25 115.63,-457.5 115.63,-457.5 115.63,-451.5 121.63,-445.5 127.63,-445.5\"/>\n<text text-anchor=\"middle\" x=\"146.88\" y=\"-489.75\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">grad=1.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"115.63,-484 178.13,-484\"/>\n<text text-anchor=\"middle\" x=\"146.88\" y=\"-470.5\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">value=8.50</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"115.63,-464.75 178.13,-464.75\"/>\n<text text-anchor=\"middle\" x=\"146.88\" y=\"-451.25\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">mse</text>\n</g>\n<!-- 1660817532688backward -->\n<g id=\"node7\" class=\"node\">\n<title>1660817532688backward</title>\n<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M87.63,-334.25C87.63,-334.25 126.13,-334.25 126.13,-334.25 132.13,-334.25 138.13,-340.25 138.13,-346.25 138.13,-346.25 138.13,-380 138.13,-380 138.13,-386 132.13,-392 126.13,-392 126.13,-392 87.63,-392 87.63,-392 81.63,-392 75.63,-386 75.63,-380 75.63,-380 75.63,-346.25 75.63,-346.25 75.63,-340.25 81.63,-334.25 87.63,-334.25\"/>\n<text text-anchor=\"middle\" x=\"106.88\" y=\"-378.5\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">grad=1.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"75.63,-372.75 138.13,-372.75\"/>\n<text text-anchor=\"middle\" x=\"106.88\" y=\"-359.25\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">value=1.00</text>\n<polyline fill=\"none\" stroke=\"black\" 
points=\"75.63,-353.5 138.13,-353.5\"/>\n<text text-anchor=\"middle\" x=\"106.88\" y=\"-340\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">&#45;</text>\n</g>\n<!-- 1660817536208backward&#45;&gt;1660817532688backward -->\n<g id=\"edge3\" class=\"edge\">\n<title>1660817536208backward&#45;&gt;1660817532688backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" d=\"M136.58,-445.24C131.87,-432.37 126.21,-416.91 121.1,-402.96\"/>\n<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"124.39,-401.77 117.67,-393.59 117.82,-404.18 124.39,-401.77\"/>\n<text text-anchor=\"middle\" x=\"141.28\" y=\"-413.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">1.00</text>\n</g>\n<!-- 1660817530192backward -->\n<g id=\"node10\" class=\"node\">\n<title>1660817530192backward</title>\n<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M168.63,-334.25C168.63,-334.25 207.13,-334.25 207.13,-334.25 213.13,-334.25 219.13,-340.25 219.13,-346.25 219.13,-346.25 219.13,-380 219.13,-380 219.13,-386 213.13,-392 207.13,-392 207.13,-392 168.63,-392 168.63,-392 162.63,-392 156.63,-386 156.63,-380 156.63,-380 156.63,-346.25 156.63,-346.25 156.63,-340.25 162.63,-334.25 168.63,-334.25\"/>\n<text text-anchor=\"middle\" x=\"187.88\" y=\"-378.5\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">grad=4.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"156.63,-372.75 219.13,-372.75\"/>\n<text text-anchor=\"middle\" x=\"187.88\" y=\"-359.25\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">value=4.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"156.63,-353.5 219.13,-353.5\"/>\n<text text-anchor=\"middle\" x=\"187.88\" y=\"-340\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">&#45;</text>\n</g>\n<!-- 1660817536208backward&#45;&gt;1660817530192backward -->\n<g id=\"edge7\" class=\"edge\">\n<title>1660817536208backward&#45;&gt;1660817530192backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" 
d=\"M157.44,-445.24C162.32,-432.24 168.19,-416.6 173.46,-402.54\"/>\n<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"176.59,-404.17 176.83,-393.58 170.04,-401.71 176.59,-404.17\"/>\n<text text-anchor=\"middle\" x=\"181.84\" y=\"-413.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">4.00</text>\n</g>\n<!-- 1660817523984backward -->\n<g id=\"node6\" class=\"node\">\n<title>1660817523984backward</title>\n<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"29.88\" cy=\"-251.88\" rx=\"29.88\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"29.88\" y=\"-248\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">y1=1.00</text>\n</g>\n<!-- 1660817532688backward&#45;&gt;1660817523984backward -->\n<g id=\"edge14\" class=\"edge\">\n<title>1660817532688backward&#45;&gt;1660817523984backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" stroke-dasharray=\"5,2\" d=\"M87.05,-333.99C72.44,-313.25 53.08,-285.8 41.06,-268.74\"/>\n</g>\n<!-- 1660817534224backward -->\n<g id=\"node8\" class=\"node\">\n<title>1660817534224backward</title>\n<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M89.63,-223C89.63,-223 128.13,-223 128.13,-223 134.13,-223 140.13,-229 140.13,-235 140.13,-235 140.13,-268.75 140.13,-268.75 140.13,-274.75 134.13,-280.75 128.13,-280.75 128.13,-280.75 89.63,-280.75 89.63,-280.75 83.63,-280.75 77.63,-274.75 77.63,-268.75 77.63,-268.75 77.63,-235 77.63,-235 77.63,-229 83.63,-223 89.63,-223\"/>\n<text text-anchor=\"middle\" x=\"108.88\" y=\"-267.25\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">grad=&#45;1.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"77.63,-261.5 140.13,-261.5\"/>\n<text text-anchor=\"middle\" x=\"108.88\" y=\"-248\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">value=0.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"77.63,-242.25 140.13,-242.25\"/>\n<text text-anchor=\"middle\" x=\"108.88\" y=\"-228.75\" font-family=\"Times New Roman,serif\" 
font-size=\"10.00\">+</text>\n</g>\n<!-- 1660817532688backward&#45;&gt;1660817534224backward -->\n<g id=\"edge10\" class=\"edge\">\n<title>1660817532688backward&#45;&gt;1660817534224backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" d=\"M107.4,-333.99C107.63,-321.25 107.91,-305.97 108.16,-292.12\"/>\n<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"111.66,-292.49 108.34,-282.43 104.66,-292.36 111.66,-292.49\"/>\n<text text-anchor=\"middle\" x=\"122.25\" y=\"-302.45\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;1.00</text>\n</g>\n<!-- 1660817534224backward&#45;&gt;1660817535056backward -->\n<g id=\"edge5\" class=\"edge\">\n<title>1660817534224backward&#45;&gt;1660817535056backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" d=\"M108.11,-222.74C107.76,-210 107.34,-194.72 106.96,-180.87\"/>\n<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"110.47,-181.08 106.69,-171.18 103.47,-181.27 110.47,-181.08\"/>\n<text text-anchor=\"middle\" x=\"121.81\" y=\"-191.2\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;1.00</text>\n</g>\n<!-- 1660817523472backward -->\n<g id=\"node9\" class=\"node\">\n<title>1660817523472backward</title>\n<polygon fill=\"lightgreen\" stroke=\"black\" stroke-width=\"2\" points=\"155.63,-111.75 155.63,-169.5 218.13,-169.5 218.13,-111.75 155.63,-111.75\"/>\n<text text-anchor=\"middle\" x=\"186.88\" y=\"-156\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">grad=&#45;5.00</text>\n<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"155.63,-150.25 218.13,-150.25\"/>\n<text text-anchor=\"middle\" x=\"186.88\" y=\"-136.75\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">value=0.00</text>\n<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"155.63,-131 218.13,-131\"/>\n<text text-anchor=\"middle\" x=\"186.88\" y=\"-117.5\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">b</text>\n</g>\n<!-- 
1660817534224backward&#45;&gt;1660817523472backward -->\n<g id=\"edge1\" class=\"edge\">\n<title>1660817534224backward&#45;&gt;1660817523472backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" d=\"M128.97,-222.74C138.52,-209.36 150.07,-193.18 160.34,-178.8\"/>\n<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"163.02,-181.07 165.98,-170.9 157.32,-177 163.02,-181.07\"/>\n<text text-anchor=\"middle\" x=\"166.8\" y=\"-191.2\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;1.00</text>\n</g>\n<!-- 1660817533200backward -->\n<g id=\"node11\" class=\"node\">\n<title>1660817533200backward</title>\n<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"187.88\" cy=\"-251.88\" rx=\"29.88\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"187.88\" y=\"-248\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">y2=4.00</text>\n</g>\n<!-- 1660817530192backward&#45;&gt;1660817533200backward -->\n<g id=\"edge8\" class=\"edge\">\n<title>1660817530192backward&#45;&gt;1660817533200backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" stroke-dasharray=\"5,2\" d=\"M187.88,-333.99C187.88,-313.89 187.88,-287.46 187.88,-270.33\"/>\n</g>\n<!-- 1660817528272backward -->\n<g id=\"node13\" class=\"node\">\n<title>1660817528272backward</title>\n<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M247.63,-223C247.63,-223 286.13,-223 286.13,-223 292.13,-223 298.13,-229 298.13,-235 298.13,-235 298.13,-268.75 298.13,-268.75 298.13,-274.75 292.13,-280.75 286.13,-280.75 286.13,-280.75 247.63,-280.75 247.63,-280.75 241.63,-280.75 235.63,-274.75 235.63,-268.75 235.63,-268.75 235.63,-235 235.63,-235 235.63,-229 241.63,-223 247.63,-223\"/>\n<text text-anchor=\"middle\" x=\"266.88\" y=\"-267.25\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">grad=&#45;4.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"235.63,-261.5 298.13,-261.5\"/>\n<text text-anchor=\"middle\" x=\"266.88\" y=\"-248\" font-family=\"Times New Roman,serif\" 
font-size=\"10.00\">value=0.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"235.63,-242.25 298.13,-242.25\"/>\n<text text-anchor=\"middle\" x=\"266.88\" y=\"-228.75\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">+</text>\n</g>\n<!-- 1660817530192backward&#45;&gt;1660817528272backward -->\n<g id=\"edge9\" class=\"edge\">\n<title>1660817530192backward&#45;&gt;1660817528272backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" d=\"M208.23,-333.99C217.9,-320.61 229.6,-304.43 240,-290.05\"/>\n<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"242.69,-292.3 245.72,-282.14 237.02,-288.19 242.69,-292.3\"/>\n<text text-anchor=\"middle\" x=\"246.36\" y=\"-302.45\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;4.00</text>\n</g>\n<!-- 1660817528272backward&#45;&gt;1660817533392backward -->\n<g id=\"edge13\" class=\"edge\">\n<title>1660817528272backward&#45;&gt;1660817533392backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" d=\"M267.14,-222.74C267.25,-210 267.39,-194.72 267.52,-180.87\"/>\n<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"271.02,-181.21 267.61,-171.18 264.02,-181.15 271.02,-181.21\"/>\n<text text-anchor=\"middle\" x=\"281.69\" y=\"-191.2\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;4.00</text>\n</g>\n<!-- 1660817528272backward&#45;&gt;1660817523472backward -->\n<g id=\"edge11\" class=\"edge\">\n<title>1660817528272backward&#45;&gt;1660817523472backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" d=\"M246.28,-222.74C236.48,-209.36 224.63,-193.18 214.11,-178.8\"/>\n<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"217.04,-176.89 208.31,-170.89 211.39,-181.02 217.04,-176.89\"/>\n<text text-anchor=\"middle\" x=\"245.92\" y=\"-191.2\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;4.00</text>\n</g>\n</g>\n</svg>\n",
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x182b0650ad0>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 4
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "现在我来简单解释一下这个图片，通过第 10 节我们知道，这个图是反向传播的展示图，这个图主要分为两部分： $x1$ 、$y1$ 到 mse 和 $x2$ 、 $y2$ 到 mse，我们选其一解释。\n",
    "\n",
    "正向传播是从下到上看的，我们这个模型是一个线性模型，其公式为： $ y = ax + b $ 。刚开始，我们输入了 $x1$ ， $x1$ 与图中的 $a$ 做乘法运算形成新节点，我们称呼这个节点为 $ax_1$ ，随后 $ax_1$ 与 $b$ 做加法运算产生新节点，这个节点我们叫做 $ ax_1 + b $ ，已是模型预测输出的结果值，但是我们要训练模型，要计算预测值和真实值 $y1$ 的误差，于是将这个节点的值与 $y1$ 做减法运算给到 mse 。mse 要做的是计算目前批次输入的所有 x 值与真实值的误差，于是同理 $x2$ 最终的计算结果页会给到 mse 。\n",
    "\n",
    "\n",
    "代码中，mse 损失器的计算公式如下：\n",
    "\n",
    "$MSE(a, b) = \\frac{1}{n} \\sum_{i=1}^n (y_i - ax_i - b)^2$\n",
    "\n",
    "代码是这样书写的：\n",
    "\n",
    "```python\n",
    "for item in errors:\n",
    "    value += item.value ** 2 / n\n",
    "    wrt[item] = 2 / n * item.value\n",
    "    requires_grad = requires_grad or item.requires_grad\n",
    "```\n",
    "\n",
    "`value` 是最终所有差值累加的结果，`wrt` 是保存该函数对于每个输入的偏导，不难可以推导出其偏导：\n",
    "\n",
    "$\n",
    "\\begin{aligned}\n",
    "\\text{mse} &= \\frac{1}{n} \\sum_{i=1}^{n} \\text{loss}_i^2 = \\frac{1}{n} \\left( \\text{loss}_1^2 + \\text{loss}_2^2 + \\text{loss}_3^2 + \\cdots + \\text{loss}_n^2 \\right) \\\\\n",
    "\\frac{\\partial \\text{mse}}{\\partial \\text{loss}_i} &= \\frac{1}{n} \\cdot 2 \\cdot \\text{loss}_i\n",
    "\\end{aligned}\n",
    "$\n",
    "\n",
    "所以从上图我们会发现一个问题，当输入的批次内的样本输入的样本数量非常大的时候，就会出现计算图膨胀的问题！这十分消耗内存。"
   ],
   "id": "15f898f14c50de18"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-12T07:04:27.962106Z",
     "start_time": "2025-08-12T07:04:27.189880Z"
    }
   },
   "cell_type": "code",
   "source": [
    "### 梯度积累的计算：将整个计算图拆成若干个计算图进行计算，并累计梯度，这可以解决计算图膨胀的问题\n",
    "model = Linear()\n",
    "\n",
    "loss = mse([model.error(x1, y1)])\n",
    "loss.backward()\n",
    "\n",
    "draw_graph(loss, 'backward')"
   ],
   "id": "d9eb3de1ddbb3aa9",
   "outputs": [
    {
     "data": {
      "image/svg+xml": "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Generated by graphviz version 12.1.0 (20240811.2233)\n -->\n<!-- Pages: 1 -->\n<svg width=\"189pt\" height=\"512pt\"\n viewBox=\"0.00 0.00 189.13 511.75\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 507.75)\">\n<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-507.75 185.13,-507.75 185.13,4 -4,4\"/>\n<!-- 1660818477136backward -->\n<g id=\"node1\" class=\"node\">\n<title>1660818477136backward</title>\n<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M49.63,-111.75C49.63,-111.75 88.13,-111.75 88.13,-111.75 94.13,-111.75 100.13,-117.75 100.13,-123.75 100.13,-123.75 100.13,-157.5 100.13,-157.5 100.13,-163.5 94.13,-169.5 88.13,-169.5 88.13,-169.5 49.63,-169.5 49.63,-169.5 43.63,-169.5 37.63,-163.5 37.63,-157.5 37.63,-157.5 37.63,-123.75 37.63,-123.75 37.63,-117.75 43.63,-111.75 49.63,-111.75\"/>\n<text text-anchor=\"middle\" x=\"68.88\" y=\"-156\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">grad=&#45;2.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"37.63,-150.25 100.13,-150.25\"/>\n<text text-anchor=\"middle\" x=\"68.88\" y=\"-136.75\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">value=0.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"37.63,-131 100.13,-131\"/>\n<text text-anchor=\"middle\" x=\"68.88\" y=\"-117.5\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">*</text>\n</g>\n<!-- 1660817709648backward -->\n<g id=\"node2\" class=\"node\">\n<title>1660817709648backward</title>\n<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"29.88\" cy=\"-29.38\" rx=\"29.88\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"29.88\" y=\"-25.5\" font-family=\"Times New Roman,serif\" 
font-size=\"10.00\">x1=1.50</text>\n</g>\n<!-- 1660818477136backward&#45;&gt;1660817709648backward -->\n<g id=\"edge2\" class=\"edge\">\n<title>1660818477136backward&#45;&gt;1660817709648backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" stroke-dasharray=\"5,2\" d=\"M58.84,-111.49C51.57,-91.13 41.99,-64.3 35.88,-47.19\"/>\n</g>\n<!-- 1660818479440backward -->\n<g id=\"node6\" class=\"node\">\n<title>1660818479440backward</title>\n<polygon fill=\"lightgreen\" stroke=\"black\" stroke-width=\"2\" points=\"77.63,-0.5 77.63,-58.25 140.13,-58.25 140.13,-0.5 77.63,-0.5\"/>\n<text text-anchor=\"middle\" x=\"108.88\" y=\"-44.75\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">grad=&#45;3.00</text>\n<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"77.63,-39 140.13,-39\"/>\n<text text-anchor=\"middle\" x=\"108.88\" y=\"-25.5\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">value=0.00</text>\n<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"77.63,-19.75 140.13,-19.75\"/>\n<text text-anchor=\"middle\" x=\"108.88\" y=\"-6.25\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">a</text>\n</g>\n<!-- 1660818477136backward&#45;&gt;1660818479440backward -->\n<g id=\"edge4\" class=\"edge\">\n<title>1660818477136backward&#45;&gt;1660818479440backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" d=\"M79.18,-111.49C83.89,-98.62 89.55,-83.16 94.66,-69.21\"/>\n<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"97.94,-70.43 98.09,-59.84 91.37,-68.02 97.94,-70.43\"/>\n<text text-anchor=\"middle\" x=\"105.53\" y=\"-79.95\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;3.00</text>\n</g>\n<!-- 1660817523984backward -->\n<g id=\"node3\" class=\"node\">\n<title>1660817523984backward</title>\n<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"37.88\" cy=\"-251.88\" rx=\"29.88\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"37.88\" y=\"-248\" font-family=\"Times New 
Roman,serif\" font-size=\"10.00\">y1=1.00</text>\n</g>\n<!-- 1660816630032backward -->\n<g id=\"node4\" class=\"node\">\n<title>1660816630032backward</title>\n<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M34.63,-445.5C34.63,-445.5 73.13,-445.5 73.13,-445.5 79.13,-445.5 85.13,-451.5 85.13,-457.5 85.13,-457.5 85.13,-491.25 85.13,-491.25 85.13,-497.25 79.13,-503.25 73.13,-503.25 73.13,-503.25 34.63,-503.25 34.63,-503.25 28.63,-503.25 22.63,-497.25 22.63,-491.25 22.63,-491.25 22.63,-457.5 22.63,-457.5 22.63,-451.5 28.63,-445.5 34.63,-445.5\"/>\n<text text-anchor=\"middle\" x=\"53.88\" y=\"-489.75\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">grad=1.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"22.63,-484 85.13,-484\"/>\n<text text-anchor=\"middle\" x=\"53.88\" y=\"-470.5\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">value=1.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"22.63,-464.75 85.13,-464.75\"/>\n<text text-anchor=\"middle\" x=\"53.88\" y=\"-451.25\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">mse</text>\n</g>\n<!-- 1660816632272backward -->\n<g id=\"node8\" class=\"node\">\n<title>1660816632272backward</title>\n<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M34.63,-334.25C34.63,-334.25 73.13,-334.25 73.13,-334.25 79.13,-334.25 85.13,-340.25 85.13,-346.25 85.13,-346.25 85.13,-380 85.13,-380 85.13,-386 79.13,-392 73.13,-392 73.13,-392 34.63,-392 34.63,-392 28.63,-392 22.63,-386 22.63,-380 22.63,-380 22.63,-346.25 22.63,-346.25 22.63,-340.25 28.63,-334.25 34.63,-334.25\"/>\n<text text-anchor=\"middle\" x=\"53.88\" y=\"-378.5\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">grad=2.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"22.63,-372.75 85.13,-372.75\"/>\n<text text-anchor=\"middle\" x=\"53.88\" y=\"-359.25\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">value=1.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"22.63,-353.5 85.13,-353.5\"/>\n<text 
text-anchor=\"middle\" x=\"53.88\" y=\"-340\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">&#45;</text>\n</g>\n<!-- 1660816630032backward&#45;&gt;1660816632272backward -->\n<g id=\"edge3\" class=\"edge\">\n<title>1660816630032backward&#45;&gt;1660816632272backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" d=\"M53.88,-445.24C53.88,-432.5 53.88,-417.22 53.88,-403.37\"/>\n<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"57.38,-403.68 53.88,-393.68 50.38,-403.68 57.38,-403.68\"/>\n<text text-anchor=\"middle\" x=\"65.88\" y=\"-413.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">2.00</text>\n</g>\n<!-- 1660816637200backward -->\n<g id=\"node5\" class=\"node\">\n<title>1660816637200backward</title>\n<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M97.63,-223C97.63,-223 136.13,-223 136.13,-223 142.13,-223 148.13,-229 148.13,-235 148.13,-235 148.13,-268.75 148.13,-268.75 148.13,-274.75 142.13,-280.75 136.13,-280.75 136.13,-280.75 97.63,-280.75 97.63,-280.75 91.63,-280.75 85.63,-274.75 85.63,-268.75 85.63,-268.75 85.63,-235 85.63,-235 85.63,-229 91.63,-223 97.63,-223\"/>\n<text text-anchor=\"middle\" x=\"116.88\" y=\"-267.25\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">grad=&#45;2.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"85.63,-261.5 148.13,-261.5\"/>\n<text text-anchor=\"middle\" x=\"116.88\" y=\"-248\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">value=0.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"85.63,-242.25 148.13,-242.25\"/>\n<text text-anchor=\"middle\" x=\"116.88\" y=\"-228.75\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">+</text>\n</g>\n<!-- 1660816637200backward&#45;&gt;1660818477136backward -->\n<g id=\"edge1\" class=\"edge\">\n<title>1660816637200backward&#45;&gt;1660818477136backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" d=\"M104.52,-222.74C98.81,-209.74 91.94,-194.1 85.76,-180.04\"/>\n<polygon fill=\"deepskyblue\" 
stroke=\"deepskyblue\" points=\"89.04,-178.8 81.81,-171.05 82.63,-181.61 89.04,-178.8\"/>\n<text text-anchor=\"middle\" x=\"110\" y=\"-191.2\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;2.00</text>\n</g>\n<!-- 1660818476880backward -->\n<g id=\"node7\" class=\"node\">\n<title>1660818476880backward</title>\n<polygon fill=\"lightgreen\" stroke=\"black\" stroke-width=\"2\" points=\"118.63,-111.75 118.63,-169.5 181.13,-169.5 181.13,-111.75 118.63,-111.75\"/>\n<text text-anchor=\"middle\" x=\"149.88\" y=\"-156\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">grad=&#45;2.00</text>\n<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"118.63,-150.25 181.13,-150.25\"/>\n<text text-anchor=\"middle\" x=\"149.88\" y=\"-136.75\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">value=0.00</text>\n<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"118.63,-131 181.13,-131\"/>\n<text text-anchor=\"middle\" x=\"149.88\" y=\"-117.5\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">b</text>\n</g>\n<!-- 1660816637200backward&#45;&gt;1660818476880backward -->\n<g id=\"edge6\" class=\"edge\">\n<title>1660816637200backward&#45;&gt;1660818476880backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" d=\"M125.38,-222.74C129.27,-209.87 133.94,-194.41 138.15,-180.46\"/>\n<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"141.43,-181.7 140.97,-171.11 134.73,-179.67 141.43,-181.7\"/>\n<text text-anchor=\"middle\" x=\"149.61\" y=\"-191.2\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;2.00</text>\n</g>\n<!-- 1660816632272backward&#45;&gt;1660817523984backward -->\n<g id=\"edge7\" class=\"edge\">\n<title>1660816632272backward&#45;&gt;1660817523984backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" stroke-dasharray=\"5,2\" d=\"M49.76,-333.99C46.8,-313.76 42.9,-287.13 40.39,-270.01\"/>\n</g>\n<!-- 1660816632272backward&#45;&gt;1660816637200backward -->\n<g id=\"edge5\" 
class=\"edge\">\n<title>1660816632272backward&#45;&gt;1660816637200backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" d=\"M70.1,-333.99C77.67,-320.86 86.8,-305.05 94.97,-290.88\"/>\n<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"97.99,-292.64 99.96,-282.23 91.93,-289.14 97.99,-292.64\"/>\n<text text-anchor=\"middle\" x=\"103.4\" y=\"-302.45\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;2.00</text>\n</g>\n</g>\n</svg>\n",
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x182b06289d0>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 6
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "观察上述的图像，我们发现最终给到 $a$ 的梯度是 $-3$ 而不是我们之前计算图中的 $-1.5$ ，这是因为 mse 损失器中的导数除了一个 $n$ ，想要正确累积梯度，需要进行额外的加工计算，这我们使用的 mse 损失器比较简单，所以直接除一个样本数量即可。",
   "id": "ef54064e0c49d1b7"
  },
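  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "As a worked check of the numbers in the graph, here is a short derivation. It assumes that `mse` is the mean of the squared errors and that `model.error(x, y)` computes $y - (a x + b)$; both are inferred from the node values in the graphs rather than from the source of `tools.ch07_autograd`, so treat this as a sketch.\n",
    "\n",
    "$$\n",
    "L = \\frac{1}{n} \\sum_{i=1}^{n} e_i^2, \\qquad e_i = y_i - (a x_i + b), \\qquad \\frac{\\partial L}{\\partial a} = -\\frac{2}{n} \\sum_{i=1}^{n} e_i x_i\n",
    "$$\n",
    "\n",
    "With $a = b = 0$, $x_1 = 1.5$ and $y_1 = 1$, a single-sample loss ($n = 1$) gives $\\partial L / \\partial a = -2 \\cdot 1 \\cdot 1.5 = -3$. Scaling that loss by $0.5$ gives $-1.5$, which is the contribution this sample makes inside a two-sample mean.\n"
   ],
   "id": "a1f3c5d7e9b24680"
  },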
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-12T07:10:40.352620Z",
     "start_time": "2025-08-12T07:10:39.545642Z"
    }
   },
   "cell_type": "code",
   "source": [
    "model = Linear()\n",
    "\n",
    "loss = 0.5 * mse([model.error(x1, y1)])  # 这里乘了 0.5 代表除了样本数量\n",
    "loss.backward()\n",
    "\n",
    "draw_graph(loss, 'backward')"
   ],
   "id": "285819c29a63e6fe",
   "outputs": [
    {
     "data": {
      "image/svg+xml": "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Generated by graphviz version 12.1.0 (20240811.2233)\n -->\n<!-- Pages: 1 -->\n<svg width=\"227pt\" height=\"623pt\"\n viewBox=\"0.00 0.00 226.76 623.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 619)\">\n<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-619 222.76,-619 222.76,4 -4,4\"/>\n<!-- 1660817936400backward -->\n<g id=\"node1\" class=\"node\">\n<title>1660817936400backward</title>\n<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M90.63,-223C90.63,-223 129.13,-223 129.13,-223 135.13,-223 141.13,-229 141.13,-235 141.13,-235 141.13,-268.75 141.13,-268.75 141.13,-274.75 135.13,-280.75 129.13,-280.75 129.13,-280.75 90.63,-280.75 90.63,-280.75 84.63,-280.75 78.63,-274.75 78.63,-268.75 78.63,-268.75 78.63,-235 78.63,-235 78.63,-229 84.63,-223 90.63,-223\"/>\n<text text-anchor=\"middle\" x=\"109.88\" y=\"-267.25\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">grad=&#45;1.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"78.63,-261.5 141.13,-261.5\"/>\n<text text-anchor=\"middle\" x=\"109.88\" y=\"-248\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">value=0.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"78.63,-242.25 141.13,-242.25\"/>\n<text text-anchor=\"middle\" x=\"109.88\" y=\"-228.75\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">+</text>\n</g>\n<!-- 1660818474576backward -->\n<g id=\"node3\" class=\"node\">\n<title>1660818474576backward</title>\n<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M49.63,-111.75C49.63,-111.75 88.13,-111.75 88.13,-111.75 94.13,-111.75 100.13,-117.75 100.13,-123.75 100.13,-123.75 100.13,-157.5 100.13,-157.5 100.13,-163.5 94.13,-169.5 88.13,-169.5 88.13,-169.5 
49.63,-169.5 49.63,-169.5 43.63,-169.5 37.63,-163.5 37.63,-157.5 37.63,-157.5 37.63,-123.75 37.63,-123.75 37.63,-117.75 43.63,-111.75 49.63,-111.75\"/>\n<text text-anchor=\"middle\" x=\"68.88\" y=\"-156\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">grad=&#45;1.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"37.63,-150.25 100.13,-150.25\"/>\n<text text-anchor=\"middle\" x=\"68.88\" y=\"-136.75\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">value=0.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"37.63,-131 100.13,-131\"/>\n<text text-anchor=\"middle\" x=\"68.88\" y=\"-117.5\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">*</text>\n</g>\n<!-- 1660817936400backward&#45;&gt;1660818474576backward -->\n<g id=\"edge4\" class=\"edge\">\n<title>1660817936400backward&#45;&gt;1660818474576backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" d=\"M99.32,-222.74C94.44,-209.74 88.58,-194.1 83.3,-180.04\"/>\n<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"86.73,-179.21 79.94,-171.08 80.17,-181.67 86.73,-179.21\"/>\n<text text-anchor=\"middle\" x=\"106.09\" y=\"-191.2\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;1.00</text>\n</g>\n<!-- 1660817946064backward -->\n<g id=\"node6\" class=\"node\">\n<title>1660817946064backward</title>\n<polygon fill=\"lightgreen\" stroke=\"black\" stroke-width=\"2\" points=\"118.63,-111.75 118.63,-169.5 181.13,-169.5 181.13,-111.75 118.63,-111.75\"/>\n<text text-anchor=\"middle\" x=\"149.88\" y=\"-156\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">grad=&#45;1.00</text>\n<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"118.63,-150.25 181.13,-150.25\"/>\n<text text-anchor=\"middle\" x=\"149.88\" y=\"-136.75\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">value=0.00</text>\n<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"118.63,-131 181.13,-131\"/>\n<text text-anchor=\"middle\" 
x=\"149.88\" y=\"-117.5\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">b</text>\n</g>\n<!-- 1660817936400backward&#45;&gt;1660817946064backward -->\n<g id=\"edge7\" class=\"edge\">\n<title>1660817936400backward&#45;&gt;1660817946064backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" d=\"M120.18,-222.74C124.89,-209.87 130.55,-194.41 135.66,-180.46\"/>\n<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"138.94,-181.68 139.09,-171.09 132.37,-179.27 138.94,-181.68\"/>\n<text text-anchor=\"middle\" x=\"146.53\" y=\"-191.2\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;1.00</text>\n</g>\n<!-- 1660818569808backward -->\n<g id=\"node2\" class=\"node\">\n<title>1660818569808backward</title>\n<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M130.63,-334.25C130.63,-334.25 169.13,-334.25 169.13,-334.25 175.13,-334.25 181.13,-340.25 181.13,-346.25 181.13,-346.25 181.13,-380 181.13,-380 181.13,-386 175.13,-392 169.13,-392 169.13,-392 130.63,-392 130.63,-392 124.63,-392 118.63,-386 118.63,-380 118.63,-380 118.63,-346.25 118.63,-346.25 118.63,-340.25 124.63,-334.25 130.63,-334.25\"/>\n<text text-anchor=\"middle\" x=\"149.88\" y=\"-378.5\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">grad=1.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"118.63,-372.75 181.13,-372.75\"/>\n<text text-anchor=\"middle\" x=\"149.88\" y=\"-359.25\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">value=1.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"118.63,-353.5 181.13,-353.5\"/>\n<text text-anchor=\"middle\" x=\"149.88\" y=\"-340\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">&#45;</text>\n</g>\n<!-- 1660818569808backward&#45;&gt;1660817936400backward -->\n<g id=\"edge9\" class=\"edge\">\n<title>1660818569808backward&#45;&gt;1660817936400backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" d=\"M139.58,-333.99C134.87,-321.12 129.21,-305.66 124.1,-291.71\"/>\n<polygon fill=\"deepskyblue\" 
stroke=\"deepskyblue\" points=\"127.39,-290.52 120.67,-282.34 120.82,-292.93 127.39,-290.52\"/>\n<text text-anchor=\"middle\" x=\"146.53\" y=\"-302.45\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;1.00</text>\n</g>\n<!-- 1660817523984backward -->\n<g id=\"node8\" class=\"node\">\n<title>1660817523984backward</title>\n<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"188.88\" cy=\"-251.88\" rx=\"29.88\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"188.88\" y=\"-248\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">y1=1.00</text>\n</g>\n<!-- 1660818569808backward&#45;&gt;1660817523984backward -->\n<g id=\"edge8\" class=\"edge\">\n<title>1660818569808backward&#45;&gt;1660817523984backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" stroke-dasharray=\"5,2\" d=\"M159.92,-333.99C167.19,-313.63 176.77,-286.8 182.88,-269.69\"/>\n</g>\n<!-- 1660817709648backward -->\n<g id=\"node4\" class=\"node\">\n<title>1660817709648backward</title>\n<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"29.88\" cy=\"-29.38\" rx=\"29.88\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"29.88\" y=\"-25.5\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">x1=1.50</text>\n</g>\n<!-- 1660818474576backward&#45;&gt;1660817709648backward -->\n<g id=\"edge1\" class=\"edge\">\n<title>1660818474576backward&#45;&gt;1660817709648backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" stroke-dasharray=\"5,2\" d=\"M58.84,-111.49C51.57,-91.13 41.99,-64.3 35.88,-47.19\"/>\n</g>\n<!-- 1660817944528backward -->\n<g id=\"node10\" class=\"node\">\n<title>1660817944528backward</title>\n<polygon fill=\"lightgreen\" stroke=\"black\" stroke-width=\"2\" points=\"77.63,-0.5 77.63,-58.25 140.13,-58.25 140.13,-0.5 77.63,-0.5\"/>\n<text text-anchor=\"middle\" x=\"108.88\" y=\"-44.75\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">grad=&#45;1.50</text>\n<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" 
points=\"77.63,-39 140.13,-39\"/>\n<text text-anchor=\"middle\" x=\"108.88\" y=\"-25.5\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">value=0.00</text>\n<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"77.63,-19.75 140.13,-19.75\"/>\n<text text-anchor=\"middle\" x=\"108.88\" y=\"-6.25\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">a</text>\n</g>\n<!-- 1660818474576backward&#45;&gt;1660817944528backward -->\n<g id=\"edge5\" class=\"edge\">\n<title>1660818474576backward&#45;&gt;1660817944528backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" d=\"M79.18,-111.49C83.89,-98.62 89.55,-83.16 94.66,-69.21\"/>\n<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"97.94,-70.43 98.09,-59.84 91.37,-68.02 97.94,-70.43\"/>\n<text text-anchor=\"middle\" x=\"105.53\" y=\"-79.95\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;1.50</text>\n</g>\n<!-- 1660817938512backward -->\n<g id=\"node5\" class=\"node\">\n<title>1660817938512backward</title>\n<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"63.88\" cy=\"-474.38\" rx=\"36.54\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"63.88\" y=\"-470.5\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">input=0.50</text>\n</g>\n<!-- 1660817944272backward -->\n<g id=\"node7\" class=\"node\">\n<title>1660817944272backward</title>\n<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M130.63,-445.5C130.63,-445.5 169.13,-445.5 169.13,-445.5 175.13,-445.5 181.13,-451.5 181.13,-457.5 181.13,-457.5 181.13,-491.25 181.13,-491.25 181.13,-497.25 175.13,-503.25 169.13,-503.25 169.13,-503.25 130.63,-503.25 130.63,-503.25 124.63,-503.25 118.63,-497.25 118.63,-491.25 118.63,-491.25 118.63,-457.5 118.63,-457.5 118.63,-451.5 124.63,-445.5 130.63,-445.5\"/>\n<text text-anchor=\"middle\" x=\"149.88\" y=\"-489.75\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">grad=0.50</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"118.63,-484 
181.13,-484\"/>\n<text text-anchor=\"middle\" x=\"149.88\" y=\"-470.5\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">value=1.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"118.63,-464.75 181.13,-464.75\"/>\n<text text-anchor=\"middle\" x=\"149.88\" y=\"-451.25\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">mse</text>\n</g>\n<!-- 1660817944272backward&#45;&gt;1660818569808backward -->\n<g id=\"edge3\" class=\"edge\">\n<title>1660817944272backward&#45;&gt;1660818569808backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" d=\"M149.88,-445.24C149.88,-432.5 149.88,-417.22 149.88,-403.37\"/>\n<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"153.38,-403.68 149.88,-393.68 146.38,-403.68 153.38,-403.68\"/>\n<text text-anchor=\"middle\" x=\"161.88\" y=\"-413.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">1.00</text>\n</g>\n<!-- 1660817940368backward -->\n<g id=\"node9\" class=\"node\">\n<title>1660817940368backward</title>\n<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M60.63,-556.75C60.63,-556.75 99.13,-556.75 99.13,-556.75 105.13,-556.75 111.13,-562.75 111.13,-568.75 111.13,-568.75 111.13,-602.5 111.13,-602.5 111.13,-608.5 105.13,-614.5 99.13,-614.5 99.13,-614.5 60.63,-614.5 60.63,-614.5 54.63,-614.5 48.63,-608.5 48.63,-602.5 48.63,-602.5 48.63,-568.75 48.63,-568.75 48.63,-562.75 54.63,-556.75 60.63,-556.75\"/>\n<text text-anchor=\"middle\" x=\"79.88\" y=\"-601\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">grad=1.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"48.63,-595.25 111.13,-595.25\"/>\n<text text-anchor=\"middle\" x=\"79.88\" y=\"-581.75\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">value=0.50</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"48.63,-576 111.13,-576\"/>\n<text text-anchor=\"middle\" x=\"79.88\" y=\"-562.5\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">*</text>\n</g>\n<!-- 
1660817940368backward&#45;&gt;1660817938512backward -->\n<g id=\"edge6\" class=\"edge\">\n<title>1660817940368backward&#45;&gt;1660817938512backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" stroke-dasharray=\"5,2\" d=\"M75.76,-556.49C72.82,-536.39 68.95,-509.96 66.44,-492.83\"/>\n</g>\n<!-- 1660817940368backward&#45;&gt;1660817944272backward -->\n<g id=\"edge2\" class=\"edge\">\n<title>1660817940368backward&#45;&gt;1660817944272backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" d=\"M97.91,-556.49C106.4,-543.24 116.65,-527.24 125.8,-512.97\"/>\n<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"128.65,-515 131.1,-504.69 122.76,-511.22 128.65,-515\"/>\n<text text-anchor=\"middle\" x=\"131.07\" y=\"-524.95\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.50</text>\n</g>\n</g>\n</svg>\n",
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x182b0773110>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 7
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-12T07:11:24.694430Z",
     "start_time": "2025-08-12T07:11:23.951633Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# 进行第二次传播（这就解释了为什么 Pytorch 中需要手动清空之前的梯度，每次我们在新的一轮进行训练时总是要手动清空梯度，不清空为的就是梯度积累）\n",
    "loss = 0.5 * mse([model.error(x2, y2)])  # 这里乘了 0.5 代表除了样本数量\n",
    "loss.backward()\n",
    "\n",
    "draw_graph(loss, 'backward')"
   ],
   "id": "cf7bb467cc1814db",
   "outputs": [
    {
     "data": {
      "image/svg+xml": "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Generated by graphviz version 12.1.0 (20240811.2233)\n -->\n<!-- Pages: 1 -->\n<svg width=\"235pt\" height=\"623pt\"\n viewBox=\"0.00 0.00 235.13 623.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 619)\">\n<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-619 231.13,-619 231.13,4 -4,4\"/>\n<!-- 1662060544080backward -->\n<g id=\"node1\" class=\"node\">\n<title>1662060544080backward</title>\n<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M24.01,-445.5C24.01,-445.5 67.76,-445.5 67.76,-445.5 73.76,-445.5 79.76,-451.5 79.76,-457.5 79.76,-457.5 79.76,-491.25 79.76,-491.25 79.76,-497.25 73.76,-503.25 67.76,-503.25 67.76,-503.25 24.01,-503.25 24.01,-503.25 18.01,-503.25 12.01,-497.25 12.01,-491.25 12.01,-491.25 12.01,-457.5 12.01,-457.5 12.01,-451.5 18.01,-445.5 24.01,-445.5\"/>\n<text text-anchor=\"middle\" x=\"45.88\" y=\"-489.75\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">grad=0.50</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"12.01,-484 79.76,-484\"/>\n<text text-anchor=\"middle\" x=\"45.88\" y=\"-470.5\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">value=16.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"12.01,-464.75 79.76,-464.75\"/>\n<text text-anchor=\"middle\" x=\"45.88\" y=\"-451.25\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">mse</text>\n</g>\n<!-- 1662060615312backward -->\n<g id=\"node3\" class=\"node\">\n<title>1662060615312backward</title>\n<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M26.63,-334.25C26.63,-334.25 65.13,-334.25 65.13,-334.25 71.13,-334.25 77.13,-340.25 77.13,-346.25 77.13,-346.25 77.13,-380 77.13,-380 77.13,-386 71.13,-392 65.13,-392 65.13,-392 26.63,-392 
26.63,-392 20.63,-392 14.63,-386 14.63,-380 14.63,-380 14.63,-346.25 14.63,-346.25 14.63,-340.25 20.63,-334.25 26.63,-334.25\"/>\n<text text-anchor=\"middle\" x=\"45.88\" y=\"-378.5\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">grad=4.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"14.63,-372.75 77.13,-372.75\"/>\n<text text-anchor=\"middle\" x=\"45.88\" y=\"-359.25\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">value=4.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"14.63,-353.5 77.13,-353.5\"/>\n<text text-anchor=\"middle\" x=\"45.88\" y=\"-340\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">&#45;</text>\n</g>\n<!-- 1662060544080backward&#45;&gt;1662060615312backward -->\n<g id=\"edge6\" class=\"edge\">\n<title>1662060544080backward&#45;&gt;1662060615312backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" d=\"M45.88,-445.24C45.88,-432.5 45.88,-417.22 45.88,-403.37\"/>\n<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"49.38,-403.68 45.88,-393.68 42.38,-403.68 49.38,-403.68\"/>\n<text text-anchor=\"middle\" x=\"57.88\" y=\"-413.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">4.00</text>\n</g>\n<!-- 1660817531472backward -->\n<g id=\"node2\" class=\"node\">\n<title>1660817531472backward</title>\n<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"116.88\" cy=\"-29.38\" rx=\"29.88\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"116.88\" y=\"-25.5\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">x2=2.00</text>\n</g>\n<!-- 1660817533200backward -->\n<g id=\"node8\" class=\"node\">\n<title>1660817533200backward</title>\n<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"29.88\" cy=\"-251.88\" rx=\"29.88\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"29.88\" y=\"-248\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">y2=4.00</text>\n</g>\n<!-- 1662060615312backward&#45;&gt;1660817533200backward -->\n<g id=\"edge1\" 
class=\"edge\">\n<title>1662060615312backward&#45;&gt;1660817533200backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" stroke-dasharray=\"5,2\" d=\"M41.76,-333.99C38.8,-313.76 34.9,-287.13 32.39,-270.01\"/>\n</g>\n<!-- 1660776374608backward -->\n<g id=\"node9\" class=\"node\">\n<title>1660776374608backward</title>\n<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M89.63,-223C89.63,-223 128.13,-223 128.13,-223 134.13,-223 140.13,-229 140.13,-235 140.13,-235 140.13,-268.75 140.13,-268.75 140.13,-274.75 134.13,-280.75 128.13,-280.75 128.13,-280.75 89.63,-280.75 89.63,-280.75 83.63,-280.75 77.63,-274.75 77.63,-268.75 77.63,-268.75 77.63,-235 77.63,-235 77.63,-229 83.63,-223 89.63,-223\"/>\n<text text-anchor=\"middle\" x=\"108.88\" y=\"-267.25\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">grad=&#45;4.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"77.63,-261.5 140.13,-261.5\"/>\n<text text-anchor=\"middle\" x=\"108.88\" y=\"-248\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">value=0.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"77.63,-242.25 140.13,-242.25\"/>\n<text text-anchor=\"middle\" x=\"108.88\" y=\"-228.75\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">+</text>\n</g>\n<!-- 1662060615312backward&#45;&gt;1660776374608backward -->\n<g id=\"edge7\" class=\"edge\">\n<title>1662060615312backward&#45;&gt;1660776374608backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" d=\"M62.1,-333.99C69.67,-320.86 78.8,-305.05 86.97,-290.88\"/>\n<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"89.99,-292.64 91.96,-282.23 83.93,-289.14 89.99,-292.64\"/>\n<text text-anchor=\"middle\" x=\"95.4\" y=\"-302.45\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;4.00</text>\n</g>\n<!-- 1660817946064backward -->\n<g id=\"node4\" class=\"node\">\n<title>1660817946064backward</title>\n<polygon fill=\"lightgreen\" stroke=\"black\" stroke-width=\"2\" points=\"44.63,-111.75 44.63,-169.5 
107.13,-169.5 107.13,-111.75 44.63,-111.75\"/>\n<text text-anchor=\"middle\" x=\"75.88\" y=\"-156\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">grad=&#45;5.00</text>\n<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"44.63,-150.25 107.13,-150.25\"/>\n<text text-anchor=\"middle\" x=\"75.88\" y=\"-136.75\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">value=0.00</text>\n<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"44.63,-131 107.13,-131\"/>\n<text text-anchor=\"middle\" x=\"75.88\" y=\"-117.5\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">b</text>\n</g>\n<!-- 1660816664784backward -->\n<g id=\"node5\" class=\"node\">\n<title>1660816664784backward</title>\n<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M70.63,-556.75C70.63,-556.75 109.13,-556.75 109.13,-556.75 115.13,-556.75 121.13,-562.75 121.13,-568.75 121.13,-568.75 121.13,-602.5 121.13,-602.5 121.13,-608.5 115.13,-614.5 109.13,-614.5 109.13,-614.5 70.63,-614.5 70.63,-614.5 64.63,-614.5 58.63,-608.5 58.63,-602.5 58.63,-602.5 58.63,-568.75 58.63,-568.75 58.63,-562.75 64.63,-556.75 70.63,-556.75\"/>\n<text text-anchor=\"middle\" x=\"89.88\" y=\"-601\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">grad=1.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"58.63,-595.25 121.13,-595.25\"/>\n<text text-anchor=\"middle\" x=\"89.88\" y=\"-581.75\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">value=8.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"58.63,-576 121.13,-576\"/>\n<text text-anchor=\"middle\" x=\"89.88\" y=\"-562.5\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">*</text>\n</g>\n<!-- 1660816664784backward&#45;&gt;1662060544080backward -->\n<g id=\"edge8\" class=\"edge\">\n<title>1660816664784backward&#45;&gt;1662060544080backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" d=\"M78.55,-556.49C73.31,-543.49 67.02,-527.85 61.35,-513.79\"/>\n<polygon fill=\"deepskyblue\" 
stroke=\"deepskyblue\" points=\"64.72,-512.79 57.74,-504.82 58.23,-515.4 64.72,-512.79\"/>\n<text text-anchor=\"middle\" x=\"82.52\" y=\"-524.95\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.50</text>\n</g>\n<!-- 1662060558032backward -->\n<g id=\"node7\" class=\"node\">\n<title>1662060558032backward</title>\n<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"133.88\" cy=\"-474.38\" rx=\"36.54\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"133.88\" y=\"-470.5\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">input=0.50</text>\n</g>\n<!-- 1660816664784backward&#45;&gt;1662060558032backward -->\n<g id=\"edge9\" class=\"edge\">\n<title>1660816664784backward&#45;&gt;1662060558032backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" stroke-dasharray=\"5,2\" d=\"M101.21,-556.49C109.36,-536.26 120.08,-509.63 126.98,-492.51\"/>\n</g>\n<!-- 1660818487504backward -->\n<g id=\"node6\" class=\"node\">\n<title>1660818487504backward</title>\n<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M137.63,-111.75C137.63,-111.75 176.13,-111.75 176.13,-111.75 182.13,-111.75 188.13,-117.75 188.13,-123.75 188.13,-123.75 188.13,-157.5 188.13,-157.5 188.13,-163.5 182.13,-169.5 176.13,-169.5 176.13,-169.5 137.63,-169.5 137.63,-169.5 131.63,-169.5 125.63,-163.5 125.63,-157.5 125.63,-157.5 125.63,-123.75 125.63,-123.75 125.63,-117.75 131.63,-111.75 137.63,-111.75\"/>\n<text text-anchor=\"middle\" x=\"156.88\" y=\"-156\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">grad=&#45;4.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"125.63,-150.25 188.13,-150.25\"/>\n<text text-anchor=\"middle\" x=\"156.88\" y=\"-136.75\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">value=0.00</text>\n<polyline fill=\"none\" stroke=\"black\" points=\"125.63,-131 188.13,-131\"/>\n<text text-anchor=\"middle\" x=\"156.88\" y=\"-117.5\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">*</text>\n</g>\n<!-- 
1660818487504backward&#45;&gt;1660817531472backward -->\n<g id=\"edge2\" class=\"edge\">\n<title>1660818487504backward&#45;&gt;1660817531472backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" stroke-dasharray=\"5,2\" d=\"M146.58,-111.49C139.13,-91.13 129.3,-64.3 123.04,-47.19\"/>\n</g>\n<!-- 1660817944528backward -->\n<g id=\"node10\" class=\"node\">\n<title>1660817944528backward</title>\n<polygon fill=\"lightgreen\" stroke=\"black\" stroke-width=\"2\" points=\"164.63,-0.5 164.63,-58.25 227.13,-58.25 227.13,-0.5 164.63,-0.5\"/>\n<text text-anchor=\"middle\" x=\"195.88\" y=\"-44.75\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">grad=&#45;9.50</text>\n<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"164.63,-39 227.13,-39\"/>\n<text text-anchor=\"middle\" x=\"195.88\" y=\"-25.5\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">value=0.00</text>\n<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"164.63,-19.75 227.13,-19.75\"/>\n<text text-anchor=\"middle\" x=\"195.88\" y=\"-6.25\" font-family=\"Times New Roman,serif\" font-size=\"10.00\">a</text>\n</g>\n<!-- 1660818487504backward&#45;&gt;1660817944528backward -->\n<g id=\"edge3\" class=\"edge\">\n<title>1660818487504backward&#45;&gt;1660817944528backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" d=\"M166.92,-111.49C171.52,-98.62 177.04,-83.16 182.02,-69.21\"/>\n<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"185.3,-70.43 185.36,-59.84 178.7,-68.08 185.3,-70.43\"/>\n<text text-anchor=\"middle\" x=\"192.97\" y=\"-79.95\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;8.00</text>\n</g>\n<!-- 1660776374608backward&#45;&gt;1660817946064backward -->\n<g id=\"edge5\" class=\"edge\">\n<title>1660776374608backward&#45;&gt;1660817946064backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" d=\"M97.6,-222.61C95.41,-216.67 93.22,-210.42 91.38,-204.5 89.05,-197 86.84,-188.91 84.85,-181.11\"/>\n<polygon 
fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"88.26,-180.34 82.48,-171.46 81.47,-182.01 88.26,-180.34\"/>\n<text text-anchor=\"middle\" x=\"105.63\" y=\"-191.2\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;4.00</text>\n</g>\n<!-- 1660776374608backward&#45;&gt;1660818487504backward -->\n<g id=\"edge4\" class=\"edge\">\n<title>1660776374608backward&#45;&gt;1660818487504backward</title>\n<path fill=\"none\" stroke=\"deepskyblue\" d=\"M121.24,-222.74C126.95,-209.74 133.82,-194.1 140,-180.04\"/>\n<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"143.13,-181.61 143.95,-171.05 136.73,-178.8 143.13,-181.61\"/>\n<text text-anchor=\"middle\" x=\"150\" y=\"-191.2\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;4.00</text>\n</g>\n</g>\n</svg>\n",
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x182b0630810>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 8
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-12T07:19:11.508934Z",
     "start_time": "2025-08-12T07:19:11.416326Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# 梯度积累的训练方式\n",
    "model = Linear()\n",
    "\n",
    "batch_size = 20\n",
    "learning_rate = 0.1\n",
    "\n",
    "gradient_accu_iter = 4  # 累计的次数\n",
    "\n",
    "# 写的代码和视频里不一样，可能有错误，要保证梯度累计的次数是批量大小的因数\n",
    "\n",
    "for t in range(20):\n",
    "    ix = (t * batch_size) % len(x)\n",
    "    xx = x[ix: ix + batch_size]\n",
    "    yy = y[ix: ix + batch_size]\n",
    "\n",
    "    for gt in range(gradient_accu_iter):\n",
    "        xxx = xx[gt * gradient_accu_iter: (gt + 1) * gradient_accu_iter]\n",
    "        yyy = yy[gt * gradient_accu_iter: (gt + 1) * gradient_accu_iter]\n",
    "        loss = mse([model.error(_x, _y) for _x, _y in zip(xxx, yyy)])\n",
    "        loss = loss * (1 / gradient_accu_iter)\n",
    "        loss.backward()\n",
    "\n",
    "    model.a -= learning_rate * model.a.grad\n",
    "    model.b -= learning_rate * model.b.grad\n",
    "\n",
    "    model.a.grad = 0.0\n",
    "    model.b.grad = 0.0\n",
    "\n",
    "    print(model.string())"
   ],
   "id": "2c033362800fbc7",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "y = 3.33 * x + -2.09\n",
      "y = 3.64 * x + -2.33\n",
      "y = 3.37 * x + -2.02\n",
      "y = 2.97 * x + -1.28\n",
      "y = 2.77 * x + -0.27\n",
      "y = 2.95 * x + 1.00\n",
      "y = 3.68 * x + 2.47\n",
      "y = 4.98 * x + 4.04\n",
      "y = 6.61 * x + 5.42\n",
      "y = 8.18 * x + 6.46\n",
      "y = 9.42 * x + 5.68\n",
      "y = 9.76 * x + 5.40\n",
      "y = 9.89 * x + 5.26\n",
      "y = 9.89 * x + 5.27\n",
      "y = 9.89 * x + 5.26\n",
      "y = 9.88 * x + 5.22\n",
      "y = 9.87 * x + 5.18\n",
      "y = 9.86 * x + 5.18\n",
      "y = 9.87 * x + 5.19\n",
      "y = 10.00 * x + 5.28\n"
     ]
    }
   ],
   "execution_count": 11
  },
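  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "Why scaling each micro-batch loss by `1 / gradient_accu_iter` works can be sketched as follows, assuming `mse` is the mean of the squared errors (an assumption inferred from the graphs, not taken from the library source). Split a batch of $B$ samples into $k$ micro-batches $M_1, \\dots, M_k$ of equal size $m = B / k$, and write $e_i$ for the error of sample $i$:\n",
    "\n",
    "$$\n",
    "\\sum_{j=1}^{k} \\frac{1}{k} \\cdot \\frac{1}{m} \\sum_{i \\in M_j} e_i^2 = \\frac{1}{k m} \\sum_{i=1}^{B} e_i^2 = \\frac{1}{B} \\sum_{i=1}^{B} e_i^2\n",
    "$$\n",
    "\n",
    "Because differentiation is linear, the gradients accumulated over the $k$ backward passes add up to the gradient of the full-batch loss, provided the micro-batches have equal size and together cover the batch. In the loop above there are 4 micro-batches of 4 samples each, so the identity holds for the 16 samples that are actually used.\n"
   ],
   "id": "b2e4d6f8a0c13579"
  },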
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "# 参数冻结和随机失活\n",
    "\n",
    "参数冻结通常是为了在训练过程中防止特定层的权重被更新。这在迁移学习或微调预训练模型时尤其有用，当我们只想训练模型的一部分时，可以通过冻结其他部分来加速训练过程并减少过拟合的风险。在 Pytorch 中就是把 requires_grad 设置为 False 。\n",
    "\n",
    "随机失活（dropout）是对具有深度结构的人工神经网络进行优化的方法，在学习过程中通过将隐含层的部分权重或输出随机归零，降低节点间的相互依赖性从而实现神经网络的正则化，降低其结构风险。做法是对某些节点的权重乘以 0 ，使其传播时梯度就为 0 了。\n"
   ],
   "id": "3b931da9fddec2a3"
  }
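  ,
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "The statement that zeroed nodes receive zero gradient can be written out in one line. The symbols are introduced here for illustration only: $h_i$ is a hidden activation, $p$ is the drop probability, $m_i$ is a random mask, and $L$ is the loss.\n",
    "\n",
    "$$\n",
    "m_i \\sim \\mathrm{Bernoulli}(1 - p), \\qquad \\tilde{h}_i = m_i h_i, \\qquad \\frac{\\partial L}{\\partial h_i} = m_i \\frac{\\partial L}{\\partial \\tilde{h}_i}\n",
    "$$\n",
    "\n",
    "When $m_i = 0$ the node contributes nothing in the forward pass and no gradient flows back through it. PyTorch's `torch.nn.Dropout` additionally scales the kept activations by $1 / (1 - p)$ during training so that the expected activation is unchanged.\n"
   ],
   "id": "c3f5e7a9b1d20468"
  }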
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
